The present workshop is based on the tutorial "Modèle Hiérarchique avec Stan" of Matthieu Authier & Eric Parent.

1 The data

In this workshop, we use flowering date data collected between 1978 and 2016 and published in Wenden et al. (2016). Data can be downloaded from this Dryad repository. This dataset contains flowering dates of 9,691 individuals/clones of Prunus avium in Europe.

Below is a figure from Wenden et al. (2016) showing the 25 studied sites in 11 European countries. Flowering dates were recorded in 12 sites. Size of the circle is proportional to the number of cultivars recorded in each site.

# Read the raw phenology data (first sheet of the Wenden et al. 2016 workbook)
dataSakura <- read_excel("../data/Sweet_cherry_phenology_data_1978-2015.xlsx", sheet = 1)

dataSakura <- dataSakura[1:1000, ] %>%                # keep only the first 1000 rows (to shorten the model running time)
  dplyr::rename(Flowering = "Full Flowering") %>%     # response variable: the date of flowering
  filter(!is.na(Flowering), !is.na(Plantation)) %>%   # drop rows missing the response or the plantation year
  dplyr::mutate(Age = Year - Plantation,              # tree age at the time of observation
                Age = pmin(Age, 14)) %>%              # cap age at 14 (14 age classes); pmin preserves NA like the ifelse it replaces
  dplyr::select(Site, Age, Cultivar, Flowering)       # keep only the columns used below

# Show the first 10 lines of the dataset
dataSakura[1:10, ] %>%
  kable(digits = 3) %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)                   # spell out FALSE: F is a reassignable binding
Site Age Cultivar Flowering
Montauban 10 Burlat 73
Montauban 10 Regina 92
Montauban 10 Satin Sumele 87
Montauban 10 Summit 93
Montauban 10 Bellise Bedel 83
Montauban 9 Ferlizac 78
Montauban 10 Ferlizac 84
Montauban 9 Fermina 90
Montauban 10 Fermina 90
Montauban 9 Fertille 83

Variation in flowering date with tree age (14 age classes):

# Sample size and mean flowering date per age class
dataSakura %>%
  group_by(Age) %>%
  summarize(Effectif = n(),
            Flowering_mean = round(mean(Flowering, na.rm = TRUE), 1)) %>%
  kable() %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)  # FALSE, not F (F can be reassigned)
Age Effectif Flowering_mean
1 2 86.5
2 81 91.7
3 91 92.6
4 103 93.9
5 99 90.8
6 114 93.9
7 104 92.7
8 82 92.5
9 63 94.1
10 53 90.9
11 25 96.2
12 26 93.5
13 11 90.6
14 5 108.2

Variation in flowering date by site (12 sites):

# Sample size and mean flowering date per site
dataSakura %>%
  group_by(Site) %>%
  summarize(Effectif = n(),
            Flowering_mean = round(mean(Flowering, na.rm = TRUE), 1)) %>%
  kable() %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)  # FALSE, not F (F can be reassigned)
Site Effectif Flowering_mean
Balandran 198 91.2
Carpentras 322 91.4
Montauban 79 88.9
St Epain 221 97.5
Toulenne 39 94.1

2 Baseline statistical model

2.1 Model equation

We start with a simple model in which we aim to model the flowering date \(y_{ijk}\) of each individual \(i\) as a function of its age \(j\) and its site \(k\), such as:

\[\begin{align} y_{ijk} & \sim \mathcal{N}(\mu_{ijk},\sigma) \tag*{Likelihood}\\[3pt] \mu_{ijk} & = \beta_0 + \alpha_j + \delta_k \tag*{Linear model}\\[3pt] \beta_0 & \sim \mathcal{N}(\mu_y, 10) \tag*{Global intercept prior}\\[3pt] \alpha_j & \sim \mathcal{N}(0,\sigma_{age})\tag*{Distribution of varying age intercepts}\\[3pt] \delta_k & \sim \mathcal{N}(0,\sigma_{site}) \tag*{Distribution of varying site intercepts}\\ \end{align}\]

We want to specify the priors for \(\sigma\), \(\sigma_{age}\) and \(\sigma_{site}\). For that, we partition the total variance \(\sigma_{tot}\) as follows:

\[\begin{align} \sigma^2_{tot} & = \sigma^2 + \sigma^2_{age} + \sigma^2_{site}\\[3pt] \sigma & = \sigma_{tot} \times \sqrt{\pi_1}\\[3pt] \sigma_{age} & = \sigma_{tot} \times \sqrt{\pi_2}\\[3pt] \sigma_{site} & = \sigma_{tot} \times \sqrt{\pi_3}\\[3pt] \end{align}\]

with \(\sum_{l=1}^3\pi_l = 1\) (see the unit simplex in stan) and \(\sigma_{tot} \sim \mathcal{S}^+(0,1,3)\) (student prior with 3 degrees of freedom).

This model is an ANOVA with 2 factors (age & site).

2.2 Stan code

/*----------------------- Data --------------------------*/
data {
  int<lower = 1> n_obs;                               // Total number of observations
  int<lower = 1> n_age;                               // Number of different age classes
  int<lower = 1> n_site;                              // Number of different sites
  vector[n_obs] FLOWERING;                            // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];           // Age class index of each observation
  int<lower = 1, upper = n_site> SITE[n_obs];         // Site index of each observation
  real prior_location_beta0;                          // Prior mean of the global intercept
  real<lower = 0.0> prior_scale_beta0;                // Prior sd of the global intercept
}
/*----------------------- Parameters --------------------------*/
parameters {
  simplex[3] pi;                                      // unit simplex: variance shares (residual, age, site); non-negative, sums to one
  real beta0;                                         // global intercept
  real<lower = 0.0> sigma_tot;                        // Total standard deviation
  vector[n_age] alpha;                                // Age intercepts
  vector[n_site] delta;                               // Site intercepts
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
  real sigma;                                         // Residual standard deviation
  real sigma_age;                                     // Standard deviation of the age intercepts
  real sigma_site;                                    // Standard deviation of the site intercepts
  vector[n_obs] mu;                                   // linear predictor

  // Variance partition: sigma_x = sqrt(pi_x) * sigma_tot, so that
  // sigma^2 + sigma_age^2 + sigma_site^2 = sigma_tot^2
  sigma = sqrt(pi[1]) * sigma_tot;
  sigma_age = sqrt(pi[2]) * sigma_tot;
  sigma_site = sqrt(pi[3]) * sigma_tot;
  mu = rep_vector(beta0, n_obs) + alpha[AGE] + delta[SITE];
}
/*----------------------- Model --------------------------*/
model {
  // Priors
  beta0 ~ normal(prior_location_beta0, prior_scale_beta0);  // Prior of the global intercept
  sigma_tot ~ student_t(3, 0.0, 1.0);                       // Half-Student-t(3) prior (positivity via the <lower = 0> constraint)
  alpha ~ normal(0.0, sigma_age);                           // Prior of the age intercepts
  delta ~ normal(0.0, sigma_site);                          // Prior of the site intercepts

  // Likelihood
  FLOWERING ~ normal(mu, sigma);
}
/*----------------- Generated Quantities ------------------*/
generated quantities {
  vector[n_obs] log_lik;                    // Pointwise log-likelihood (used later for WAIC / LOO)
  vector[n_obs] y_rep;                      // Replicated data for posterior predictive checks

  for(i in 1:n_obs) {
    log_lik[i] = normal_lpdf(FLOWERING[i]| mu[i], sigma);  // log density of each observation
    y_rep[i] = normal_rng(mu[i], sigma);                   // draw from the posterior predictive distribution
  }
}

2.3 Running the model

Input data:

# Data passed to Stan for the baseline ANOVA model.
# Use <- (not =) for top-level assignment, per R convention.
list.baseline.model <- list(n_obs = nrow(dataSakura),
                            n_age = length(unique(dataSakura$Age)),
                            n_site = length(unique(dataSakura$Site)),
                            FLOWERING = dataSakura$Flowering,
                            AGE = dataSakura$Age,
                            SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))),
                            prior_location_beta0 = mean(dataSakura$Flowering),  # center the intercept prior on the sample mean
                            prior_scale_beta0 = 10)

Sampling:

# Fit the baseline model: 4 chains of 2000 iterations (first half is warmup), in parallel
fit.baseline.model <- sampling(baseline.model,
                               data = list.baseline.model,
                               pars = c("beta0", "alpha", "delta",
                                        "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                        "pi", "y_rep", "log_lik"),
                               save_warmup = FALSE,  # FALSE, not F (F can be reassigned)
                               iter = 2000,
                               chains = 4, cores = 4, thin = 1)
loo.baseline.model <- loo::loo(fit.baseline.model) # to compare model predictive ability

2.4 Model outputs

2.4.1 Checking parameter convergence

# Histogram of split-Rhat values over all saved parameters; values near 1 indicate convergence
stan_rhat(fit.baseline.model)

That's ok!

It would be better to also check the chain convergence, the effective sample size and the autocorrelation, but we will skip these steps here! More details here: https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html.

2.4.2 Parameter estimates

Let's look at the parameter estimates of the standard deviation for the sites (\(\sigma_{site}\)), the age (\(\sigma_{age}\)), the residuals (\(\sigma\)), the total standard deviation (\(\sigma_{tot}\)) and the relative importance of each variance component (i.e. proportion of the total variance explained by each component, i.e. site, age and residuals).

Here is the coefficients table:

# Posterior summary table for the intercept, the standard deviations and the variance shares pi
print(fit.baseline.model, digits = 3, pars = c("beta0", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi"))
## Inference for Stan model: 62d24212c46de9892dd578dab95b34f3.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##              mean se_mean    sd   2.5%    25%    50%    75%  97.5% n_eff  Rhat
## beta0      92.774   0.074 1.892 88.917 91.690 92.792 93.866 96.443   654 1.009
## sigma       7.079   0.003 0.175  6.743  6.958  7.076  7.197  7.430  4107 1.000
## sigma_age   1.530   0.023 0.676  0.597  1.050  1.409  1.860  3.205   839 1.007
## sigma_site  3.772   0.040 1.340  1.956  2.824  3.515  4.440  7.027  1145 1.007
## sigma_tot   8.270   0.024 0.757  7.367  7.772  8.105  8.548 10.183  1016 1.006
## pi[1]       0.747   0.003 0.109  0.486  0.686  0.764  0.828  0.901  1163 1.007
## pi[2]       0.040   0.001 0.036  0.005  0.016  0.029  0.051  0.138   945 1.007
## pi[3]       0.213   0.003 0.109  0.068  0.131  0.190  0.271  0.484  1205 1.008
## 
## Samples were drawn using NUTS(diag_e) at Sat Jun  5 07:42:25 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Plotting the credible intervals of the parameters of interest

First, the relative importance of each variance component: \(\pi_1\) for the residuals, \(\pi_2\) for the age and \(\pi_3\) for the sites.

# Credible intervals of the variance shares pi[1] (residual), pi[2] (age), pi[3] (site)
mcmc_intervals(fit.baseline.model,
               regex_pars = "^pi",
               prob = 0.95,
               prob_outer = 0.99,
               point_est = "median") +
  theme_bw() +
  theme(axis.text = element_text(size = 16))

The standard deviations:

# Helpers around coda: bounds of the alpha-level HPD interval, and a
# 4-number posterior summary c(mean, sd, HPD lower, HPD upper).
lower <- function(x, alpha = 0.8) {
  coda::HPDinterval(coda::as.mcmc(x), prob = alpha)[1]
}
upper <- function(x, alpha = 0.8) {
  coda::HPDinterval(coda::as.mcmc(x), prob = alpha)[2]
}
get_summary <- function(x, alpha = 0.8) {
  c(mean(x), sd(x), coda::HPDinterval(coda::as.mcmc(x), prob = alpha))
}

# Posterior summary (mean, sd, 80% HPD) of each variance component
component_pars <- c("sigma", "sigma_age", "sigma_site", "sigma_tot")
summary_anova <- as.data.frame(
  do.call(
    "rbind",
    lapply(component_pars, function(param) {
      draws <- as.numeric(rstan::extract(fit.baseline.model, param)[[1]])
      get_summary(draws)
    })
  )
)

names(summary_anova) <- c("mean", "se", "lower", "upper")
summary_anova$component <- c("residual", "age", "site", "total")

# Dot-and-interval plot, components ordered by posterior mean
summary_anova %>%
  mutate(component = factor(component,
                            levels = c("residual", "age", "site", "total")[order(mean)])) %>%
  ggplot(aes(x = component, y = mean)) +
  geom_linerange(aes(x = component, ymin = lower, ymax = upper)) +
  geom_point(size = 2) +
  ylab("Estimate") +
  xlab("Source of variation") +
  coord_flip() +
  theme_bw()

# For the graph, we could also have used:
# fit.baseline.model %>%  mcmc_intervals(regex_pars = "^sigma",
#                     prob=0.95,
#                     prob_outer=0.99,
#                     point_est = "median") +  theme_bw() +
#   theme(axis.text = element_text(size=16))

As sites seem to considerably impact the total variance, we can display the parameters \(\delta_k\):

# Observed per-site sample sizes and mean flowering dates (red overlay on the plot)
freq <- dataSakura %>% 
  group_by(Site) %>% 
  summarize(effectif = n(),
            flowering = mean(Flowering, na.rm = TRUE)
            ) 

# Append a 6th "Average" row: weighted overall mean flowering date.
# NOTE(review): assigning freq[6, 1] first expands freq to 6 rows with NA in the
# other columns, so mean(freq$effectif) on the next line includes that NA and
# returns NA — compute it before adding the row, or use na.rm = TRUE. Confirm intent.
moyenne = with(freq, sum(effectif * flowering) / sum(effectif))
freq[6, 1] = "Average"
freq[6, 2] = mean(freq$effectif)
freq[6, 3] = moyenne

# Posterior of beta0 + delta_k per site: the beta0 draws are replicated across
# site columns (rep(..., each) + matrix(byrow = TRUE) gives one column per site),
# added to the delta draws, and each column is summarized with get_summary
# (mean, sd, 80% HPD).
post_site <- as.data.frame(t(apply(matrix(rep(rstan::extract(fit.baseline.model, "beta0")$beta0, each = length(unique(dataSakura$Site))), ncol = length(unique(dataSakura$Site)), byrow = TRUE) + rstan::extract(fit.baseline.model, "delta")$delta, 2, get_summary)))
names(post_site) <- c("mean", "se", "lower", "upper")
# Only the observed sites are labelled; the "Average" row of freq has no
# posterior counterpart and is therefore dropped by the match() below.
post_site$where <- c(unique(dataSakura$Site))  #, "Average")
post_site <- cbind(post_site,
                   freq[match(post_site$where, freq$Site), c('flowering', 'effectif')]
                   )

# Posterior intervals (black) with observed site means (red, sized by sample size)
post_site %>% 
  mutate(where = factor(where, levels = c(unique(dataSakura$Site), "Average")[order(mean)])) %>% 
  ggplot(aes(x = where, y = mean)) +
  geom_linerange(aes(x = where, ymin = lower, ymax = upper)) +
  geom_point(size = 2) +
  geom_point(aes(x = where, y = flowering, size = effectif), color = 'red', alpha = 0.3) +
  scale_y_continuous(name = "Estimate (days)", breaks = 95 -10:10) + 
  xlab("Site") +
  coord_flip() +
  theme_bw()

2.4.3 Posterior predictive checks

# Posterior predictive check: observed density vs 50 replicated datasets
y_rep_draws <- as.matrix(fit.baseline.model, pars = "y_rep")[1:50, ]
ppc_dens_overlay(y = dataSakura$Flowering, yrep = y_rep_draws) +
  theme_bw() +
  theme(legend.text = element_text(size = 25),
        legend.title = element_text(size = 18),
        axis.text = element_text(size = 18),
        legend.position = c(0.8, 0.6))

3 Writing our own likelihood function

To write our own likelihood function, we will use the `target +=` statement, which directly increments the log density of the posterior, up to an additive constant.

From Bob Carpenter's comment in Stackoverflow: target += u adds u to the target log density. The target density is the density from which the sampler samples and it needs to be equal to the joint density of all the parameters given the data up to a constant.

But first, we have to mathematically write our own likelihood!

3.1 Model equation

Let's assume there are two genetically distinct types of individuals that can be differentiated based on their flowering date : early or late flowering. We want to identify individuals with the genetic potential of flowering earlier.

Let \(p\) be the probability that individual \(i\) flowers late, and \(1-p\) the probability that it flowers early.

Then, the model becomes: \[\begin{align} y_{ijk} & \sim \mathcal{N}(\mu_{ijk}^l,\sigma) \tag*{Likelihood}\\[3pt] \mu_{ijk}^l & = \beta_l + \alpha_j + \delta_k \tag*{Linear model}\\[3pt] \\ \end{align}\]

with \(l \in \{1,2\}\). Therefore, we introduced a supplementary discrete latent variable: \(z_{ijk} \sim \mathcal{B}(p)\) which models the state (early \(l=1\) or late \(l=2\)) according to probability \(p\).

As a consequence, the likelihood is: \[\begin{align} \mathcal{L}(y_{ijk}) = (1-p) \times \mathcal{N}(\beta_1 + \alpha_j + \delta_k, \sigma) + p \times \mathcal{N}(\beta_2 + \alpha_j + \delta_k, \sigma) \\ \end{align}\] Thus, the log-likelihood is: \[\begin{align} l(y_{ijk}) = \log{[(1-p) \times \mathcal{N}(\beta_1 + \alpha_j + \delta_k, \sigma) + p \times \mathcal{N}(\beta_2 + \alpha_j + \delta_k, \sigma)]} \\ \end{align}\]

3.2 Stan code

Due to the presence of a discrete variable \(z_{ijk}\), the likelihood is then implemented using the target function instead of \(\sim\).

First way of doing it, we can directly write the likelihood function in the model block:

/*----------------------- Data --------------------------*/
data {
  int<lower = 1> n_obs;                                              // Total number of observations
  int<lower = 1> n_age;                                              // Number of different age classes
  int<lower = 1> n_site;                                             // Number of different sites
  vector[n_obs] FLOWERING;                                           // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];                          // Age class index of each observation
  int<lower = 1, upper = n_site> SITE[n_obs];                        // Site index of each observation
  real prior_location_beta0;                                         // Prior mean of the global intercept
  real<lower = 0.0> prior_scale_beta0;                               // Prior sd of the global intercept
  real prior_location_diff;                                          // Prior mean of the early/late difference
  real<lower = 0.0> prior_scale_diff;                                // Prior sd of the early/late difference
}

/*----------------------- Parameters --------------------------*/
parameters {
  real<lower = 0.0, upper = 1.0> p;                                  // probability that an individual is late flowering
  simplex[3] pi;                                                     // unit simplex: variance shares (residual, age, site); sums to one
  real beta0;                                                        // standardized intercept deviate (non-centered; actual intercept built below)
  real<lower = 0.0> sigma_tot;                                       // Total standard deviation
  vector[n_age] alpha;                                               // Age intercepts
  vector[n_site] delta;                                              // Site intercepts
  real<lower = 0> difference;                                        // difference between beta_1 and beta_2; constrained positive so beta[2] > beta[1]
}


/*------------------- Transformed Parameters --------------------*/
transformed parameters {
  real sigma;                                                        // Residual standard deviation
  real sigma_age;                                                    // Standard deviation of the age intercepts
  real sigma_site;                                                   // Standard deviation of the site intercepts
  vector[n_obs] mu[2];                                               // linear predictors of the two components (1 = early, 2 = late)
  vector[2] beta;                                                    // component intercepts
  beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;        // non-centered: location + deviate * scale
  beta[2] = beta[1] + difference;                                    // late intercept, shifted by the (positive) difference
  sigma = sqrt(pi[1]) * sigma_tot;                                   // variance partition: sigma_x = sqrt(pi_x) * sigma_tot
  sigma_age = sqrt(pi[2]) * sigma_tot;
  sigma_site = sqrt(pi[3]) * sigma_tot;
  mu[1] = rep_vector(beta[1], n_obs) + alpha[AGE] + delta[SITE];
  mu[2] = rep_vector(beta[2], n_obs) + alpha[AGE] + delta[SITE];
}

/*----------------------- Model --------------------------*/
model {
  // Priors
  beta0 ~ normal(0.0, 1.0);                                          // standard normal on the deviate (non-centered parameterization)
  difference ~ normal(prior_location_diff, prior_scale_diff);
  sigma_tot ~ student_t(3, 0.0, 1.0);                                // Prior of the total standard deviation
  alpha ~ normal(0.0, sigma_age);                                    // Prior of the age intercepts
  delta ~ normal(0.0, sigma_site);                                   // Prior of the site intercepts
  
  // Our own likelihood: the discrete latent state is marginalized out,
  // log[(1-p) N(y|mu1,sigma) + p N(y|mu2,sigma)] computed stably with log_sum_exp
  for(i in 1:n_obs) {
    target += log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma), log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
  }
}

/*----------------------- Extracting the log-likelihood  --------------------------*/
generated quantities {
  vector[n_obs] log_lik;                                             // Pointwise log-likelihood (same mixture density as the model block)

  for(i in 1:n_obs) {
    log_lik[i] = log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma), log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
  }
}

Second way of doing it, custom-functions can be implemented by using the function block:

/*--------------------- Functions ------------------------*/
functions {
  // Log density of a two-component Gaussian mixture with common scale:
  // log[(1-prob) N(y|location[1],scale) + prob N(y|location[2],scale)].
  // The _lpdf suffix lets it be used with ~ in the model block.
  real TwoGaussianMixture_lpdf(real y, real prob, vector location, real scale) {
    real log_pdf[2];
    log_pdf[1] = log1m(prob) + normal_lpdf(y| location[1], scale);
    log_pdf[2] = log(prob) + normal_lpdf(y| location[2], scale);
    return log_sum_exp(log_pdf);
  }
  
  // RNG for the same mixture: draw the latent state z ~ Bernoulli(prob),
  // then sample from the corresponding component (used for y_rep).
  real TwoGaussianMixture_rng(real prob, vector location, real scale) {
    int z;
    z = bernoulli_rng(prob);
    return  z ? normal_rng(location[2], scale) : normal_rng(location[1], scale);
  }
}
/*----------------------- Data --------------------------*/
data {
  int<lower = 1> n_obs;                                              // Total number of observations
  int<lower = 1> n_age;                                              // Number of different age classes
  int<lower = 1> n_site;                                             // Number of different sites
  vector[n_obs] FLOWERING;                                           // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];                          // Age class index of each observation
  int<lower = 1, upper = n_site> SITE[n_obs];                        // Site index of each observation
  real prior_location_beta0;                                         // Prior mean of the global intercept
  real<lower = 0.0> prior_scale_beta0;                               // Prior sd of the global intercept
  real prior_location_diff;                                          // Prior mean of the early/late difference
  real<lower = 0.0> prior_scale_diff;                                // Prior sd of the early/late difference
}

/*----------------------- Parameters --------------------------*/
parameters {
  real<lower = 0.0, upper = 1.0> p;                                  // probability that an individual is late flowering
  simplex[3] pi;                                                     // unit simplex: variance shares (residual, age, site); sums to one
  real beta0;                                                        // standardized intercept deviate (non-centered; actual intercept built below)
  real<lower = 0.0> sigma_tot;                                       // Total standard deviation
  vector[n_age] alpha;                                               // Age intercepts
  vector[n_site] delta;                                              // Site intercepts
  real<lower = 0> difference;                                        // difference between beta_1 and beta_2; constrained positive so beta[2] > beta[1]
}

/*------------------- Transformed Parameters --------------------*/
transformed parameters {
  real sigma;                                                        // Residual standard deviation
  real sigma_age;                                                    // Standard deviation of the age intercepts
  real sigma_site;                                                   // Standard deviation of the site intercepts
  vector[n_obs] mu[2];                                               // linear predictors of the two components (1 = early, 2 = late)
  vector[2] beta;                                                    // component intercepts
  beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;        // non-centered: location + deviate * scale
  beta[2] = beta[1] + difference;                                    // late intercept, shifted by the (positive) difference
  sigma = sqrt(pi[1]) * sigma_tot;                                   // variance partition: sigma_x = sqrt(pi_x) * sigma_tot
  sigma_age = sqrt(pi[2]) * sigma_tot;
  sigma_site = sqrt(pi[3]) * sigma_tot;
  mu[1] = rep_vector(beta[1], n_obs) + alpha[AGE] + delta[SITE];
  mu[2] = rep_vector(beta[2], n_obs) + alpha[AGE] + delta[SITE];
}

/*----------------------- Model --------------------------*/
model {
  // Priors
  beta0 ~ normal(0.0, 1.0);                                          // standard normal on the deviate (non-centered parameterization)
  difference ~ normal(prior_location_diff, prior_scale_diff);
  sigma_tot ~ student_t(3, 0.0, 1.0);                                // Prior of the total standard deviation
  alpha ~ normal(0.0, sigma_age);                                    // Prior of the age intercepts
  delta ~ normal(0.0, sigma_site);                                   // Prior of the site intercepts

  // Our own likelihood: ~ resolves TwoGaussianMixture to TwoGaussianMixture_lpdf
  for(i in 1:n_obs) {
   FLOWERING[i] ~ TwoGaussianMixture(p, to_vector(mu[1:2, i]), sigma);
  }
}

/*----------------- Generated Quantities ------------------*/
generated quantities {
  vector[n_obs] log_lik;                                             // Pointwise log-likelihood (same mixture density as the model block)
  vector[n_obs] y_rep;                                               // Replicated data for posterior predictive checks

  for(i in 1:n_obs) {
   log_lik[i] = log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
   log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
   y_rep[i] = TwoGaussianMixture_rng(p, to_vector(mu[1:2, i]), sigma);
  }
}

3.3 Running the model

Input data:

# Data for the mixture model: same as the baseline, plus the prior on the
# early/late difference. Use <- (not =) for top-level assignment.
listMix.stan <- list(n_obs = nrow(dataSakura),
                     n_age = length(unique(dataSakura$Age)),
                     n_site = length(unique(dataSakura$Site)),
                     FLOWERING = dataSakura$Flowering,
                     AGE = dataSakura$Age,
                     SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))),
                     prior_location_beta0 = mean(dataSakura$Flowering),
                     prior_scale_beta0 = 10,
                     prior_location_diff = 7,
                     prior_scale_diff = 3)

Sampling:

# Fit the mixture model (target += version): 4 parallel chains of 2000 iterations
fit.mixTarget.model <- sampling(mixTarget.code,
                                data = listMix.stan,
                                pars = c("p", "beta", "alpha", "delta",
                                         "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                         "pi", "difference", "y_rep", "log_lik"),
                                save_warmup = FALSE,  # FALSE, not F (F can be reassigned)
                                iter = 2000,
                                chains = 4,
                                cores = 4,
                                thin = 1)
## Warning: There were 1 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems

3.4 Model outputs

Checking parameter convergence:

# Histogram of split-Rhat values for the mixture model fit (values near 1 indicate convergence)
stan_rhat(fit.mixTarget.model)

Parameter estimations:

# Posterior summaries including the mixture probability p and the two component intercepts beta[1], beta[2]
print(fit.mixTarget.model, digits = 3, pars = c("p", "beta", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi", "difference"))
## Inference for Stan model: b992c238037d6f657254e4dba82efd70.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##              mean se_mean    sd   2.5%    25%    50%    75%   97.5% n_eff  Rhat
## p           0.466   0.002 0.082  0.280  0.424  0.470  0.510   0.604  2260 1.000
## beta[1]    88.696   0.071 1.976 84.939 87.459 88.642 89.885  92.761   781 1.012
## beta[2]    97.690   0.067 1.959 93.691 96.526 97.689 98.852 101.512   861 1.010
## sigma       5.486   0.013 0.423  4.865  5.196  5.416  5.686   6.604   992 1.000
## sigma_age   1.241   0.021 0.578  0.387  0.839  1.131  1.527   2.663   732 1.004
## sigma_site  3.586   0.037 1.238  1.917  2.733  3.341  4.156   6.634  1143 1.000
## sigma_tot   6.767   0.029 0.863  5.614  6.172  6.599  7.158   8.923   891 1.000
## pi[1]       0.676   0.004 0.123  0.387  0.602  0.696  0.766   0.862  1229 1.000
## pi[2]       0.040   0.001 0.037  0.003  0.016  0.029  0.052   0.139   917 1.003
## pi[3]       0.284   0.004 0.127  0.099  0.190  0.261  0.355   0.592  1233 1.000
## difference  8.994   0.037 1.142  5.923  8.553  9.228  9.725  10.518   969 1.000
## 
## Samples were drawn using NUTS(diag_e) at Sat Jun  5 07:44:02 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Comparison of the estimation of the log-likelihood between the baseline and the mix models (using WAIC):

# WAIC comparison of the two models (note: the p_waic warnings suggest LOO instead)
loo::loo_compare(
  loo::waic(rstan::extract(fit.baseline.model, "log_lik")$log_lik),
  loo::waic(rstan::extract(fit.mixTarget.model, "log_lik")$log_lik)
)
## Warning: 
## 1 (0.1%) p_waic estimates greater than 0.4. We recommend trying loo instead.

## Warning: 
## 1 (0.1%) p_waic estimates greater than 0.4. We recommend trying loo instead.
##        elpd_diff se_diff
## model2  0.0       0.0   
## model1 -1.8       2.7
# LOO-CV: namespaced loo::loo() for consistency with loo.baseline.model above
loo.mixTarget.model <- loo::loo(fit.mixTarget.model)
loo::loo_compare(loo.baseline.model, loo.mixTarget.model)
##        elpd_diff se_diff
## model2  0.0       0.0   
## model1 -1.8       2.7

Importance of each factor (age, site, others) in the total variance:

# Posterior summary (mean, sd, 80% HPD) of each variance component, mixture model
summary_anova <- as.data.frame(
  do.call(
    "rbind",
    lapply(c("sigma", "sigma_age", "sigma_site", "sigma_tot"), function(param) {
      get_summary(as.numeric(rstan::extract(fit.mixTarget.model, param)[[1]]))
    })
  )
)

names(summary_anova) <- c("mean", "se", "lower", "upper")
summary_anova$component <- c("residual", "age", "site", "total")

# Dot-and-interval plot, components ordered by posterior mean
summary_anova %>%
  mutate(component = factor(component,
                            levels = c("residual", "age", "site", "total")[order(mean)])) %>%
  ggplot(aes(x = component, y = mean)) +
  geom_linerange(aes(x = component, ymin = lower, ymax = upper)) +
  geom_point(size = 2) +
  ylab("Estimate") +
  xlab("Source of variation") +
  coord_flip() +
  theme_bw()

4 Including covariates

We are going to include the cultivars (100 cultivars) as predictors for estimating the latent state \(z_{ijk}\).

4.1 Model equation

We state that the latent discrete variable \(z_{ijk}\) follows a logistic regression, which can be written as: \[\begin{align} z_{ijk} &\sim \mathcal{B}(p_i) \\ logit(p_i) &= \eta_0 + \eta_{Cultivar} \\ \end{align}\]

4.2 Stan code

/*--------------------- Functions ------------------------*/
functions {
  real TwoGaussianMixture_lpdf(real y, real prob, vector location, real scale) {
    real log_pdf[2];
    log_pdf[1] = log1m(prob) + normal_lpdf(y| location[1], scale);
    log_pdf[2] = log(prob) + normal_lpdf(y| location[2], scale);
    return log_sum_exp(log_pdf);
  }
  real TwoGaussianMixture_rng(real prob, vector location, real scale) {
    int z;
    z = bernoulli_rng(prob);
    return  z ? normal_rng(location[2], scale) : normal_rng(location[1], scale);
  }
}
/*----------------------- Data --------------------------*/
data {
  int<lower = 1> n_obs;                                              // Total number of observations
  int<lower = 1> n_age;                                              // Number of different age classes
  int<lower = 1> n_site;                                             // Number of different sites
  int<lower = 1> n_cultivar;                                         // Number of different cultivars
  vector[n_obs] FLOWERING;                                           // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];                          // Age variable
  int<lower = 1, upper = n_site> SITE[n_obs];                        // Site variable
  int<lower = 1, upper = n_cultivar> CULTIVAR[n_obs];                // Cultivar variable
  real prior_location_beta0;
  real<lower = 0.0> prior_scale_beta0;
  real prior_location_diff;
  real<lower = 0.0> prior_scale_diff;
  real prior_location_eta0;
  real<lower = 0.0> prior_scale_eta0;
}
/*----------------------- Parameters --------------------------*/
parameters {
  simplex[3] pi;                                                     // unit complex specifying that the sum of its elements equal to one.
  real beta0;                                                        // global intercept
  real<lower = 0.0> sigma_tot;                                       // Total standard deviation
  vector[n_age] alpha;                                               // Age intercepts
  vector[n_site] delta;                                              // Site intercepts
  real<lower = 0> difference;                                        // difference between beta_1 and beta_2
  real<lower = 0.0> sigma_cultivar;
  real eta0;
  vector[n_cultivar] eta;
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
  real sigma;                                                        // Residual standard deviation
  real sigma_age;                                                    // Standard deviation of the age intercepts
  real sigma_site;                                                   // Standard deviation of the site intercepts
  vector[n_obs] mu[2];                                               // linear predictor
  vector[n_obs] p;                                                   // proba (early or late flowering)
  vector[2] beta;
  beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;
  beta[2] = beta[1] + difference;
  sigma = sqrt(pi[1]) * sigma_tot;
  sigma_age = sqrt(pi[2]) * sigma_tot;
  sigma_site = sqrt(pi[3]) * sigma_tot;
  mu[1] = beta[1] + alpha[AGE] + delta[SITE];
  mu[2] = beta[2] + alpha[AGE] + delta[SITE];
  p = inv_logit(rep_vector(eta0, n_obs) + eta[CULTIVAR]);
}
/*----------------------- Model --------------------------*/
model {
  // Priors. sigma_tot and sigma_cultivar are declared <lower = 0>, so the
  // Student-t priors below act as half-t priors on the scales.
  sigma_tot ~ student_t(3, 0.0, 1.0);
  sigma_cultivar ~ student_t(3, 0.0, 1.0);
  beta0 ~ normal(0.0, 1.0);                                          // raw (non-centered) global intercept
  difference ~ normal(prior_location_diff, prior_scale_diff);        // gap between the late and early means (> 0)
  alpha ~ normal(0.0, sigma_age);                                    // age random intercepts
  delta ~ normal(0.0, sigma_site);                                   // site random intercepts
  eta0 ~ normal(prior_location_eta0, prior_scale_eta0);              // mixture-weight intercept (logit scale)
  eta ~ normal(0.0, sigma_cultivar);                                 // cultivar effects on the mixture weight

  // Likelihood: two-component Gaussian mixture, marginalised over the
  // latent early/late indicator. log_sum_exp keeps the sum numerically
  // stable; log1m(p) = log(1 - p) weights the early component.
  for(i in 1:n_obs) {
    target += log_sum_exp(log1m(p[i]) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
    log(p[i]) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
  }
}
/*----------------- Generated Quantities ------------------*/
generated quantities {
  vector[n_obs] log_lik;                                             // Pointwise log-likelihood (used for LOO-CV below)
  vector[n_obs] y_rep;                                               // Posterior predictive replicates

  for(i in 1:n_obs) {
    // Same marginalised mixture log-density as the model block, stored
    // per observation so loo() can use it
    log_lik[i] = log_sum_exp(log1m(p[i]) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
    log(p[i]) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
    // Draw from the fitted two-component mixture; TwoGaussianMixture_rng is
    // presumably defined in a functions block not shown here -- verify upstream
    y_rep[i] = TwoGaussianMixture_rng(p[i], to_vector(mu[1:2, i]), sigma);
  }
}

4.3 Running the model

Input data:

# Bundle the data for Stan. Every categorical variable must be recoded as
# consecutive integer indices 1..n_level, because the Stan model declares
# them as int<lower = 1, upper = n_*>.
listMixCultivar.stan <- list(
  n_obs = nrow(dataSakura),
  n_age = length(unique(dataSakura$Age)),
  n_site = length(unique(dataSakura$Site)),
  n_cultivar = length(unique(dataSakura$Cultivar)),
  FLOWERING = dataSakura$Flowering,
  # Recode Age as a 1..n_age index, like SITE and CULTIVAR below: passing the
  # raw ages would index alpha out of range whenever the observed age classes
  # are not exactly 1..n_age (identical result when they are)
  AGE = as.numeric(factor(dataSakura$Age)),
  SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))),
  CULTIVAR = as.numeric(factor(dataSakura$Cultivar, levels = unique(dataSakura$Cultivar))),
  # Weakly-informative prior settings passed to the Stan program
  prior_location_beta0 = mean(dataSakura$Flowering),
  prior_scale_beta0 = 10,
  prior_location_diff = 7,
  prior_scale_diff = 3,
  prior_location_eta0 = 0.0,
  prior_scale_eta0 = 1.5
)

Sampling:

# Draw from the posterior: 4 chains run in parallel, 1000 warmup + 1000
# sampling iterations each. Only the parameters needed for the inference
# tables, diagnostics and LOO-CV (log_lik) are kept.
fit.mixCultivar.model <- sampling(mixCultivar.code,
                                  data = listMixCultivar.stan,
                                  pars = c("eta0", "eta", "sigma_cultivar",
                                           "beta", "alpha", "delta",
                                           "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                           "pi", "difference", "log_lik"),
                                  save_warmup = FALSE,  # FALSE, not F: F can be reassigned
                                  iter = 2000,
                                  chains = 4,
                                  cores = 4,
                                  thin = 1)
## Warning: There were 1 chains where the estimated Bayesian Fraction of Missing Information was low. See
## http://mc-stan.org/misc/warnings.html#bfmi-low
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess

4.4 Model outputs

Checking parameter convergence:

stan_rhat(fit.mixCultivar.model)

Parameter estimations:

print(fit.mixCultivar.model, digits = 3, pars = c("eta0", "sigma_cultivar", "beta", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi", "difference"))
## Inference for Stan model: fcabd7671bcaec1ff44fa109a42d38f4.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                  mean se_mean    sd   2.5%    25%    50%    75%   97.5% n_eff  Rhat
## eta0           -0.212   0.012 0.439 -1.086 -0.495 -0.211  0.066   0.649  1326 1.003
## sigma_cultivar  2.693   0.053 0.705  1.714  2.225  2.575  3.014   4.449   178 1.024
## beta[1]        88.213   0.060 1.666 84.906 87.182 88.176 89.199  91.625   761 1.003
## beta[2]        97.633   0.062 1.677 94.295 96.630 97.630 98.628 101.020   725 1.005
## sigma           5.292   0.011 0.235  4.860  5.127  5.282  5.441   5.783   483 1.010
## sigma_age       1.583   0.017 0.585  0.735  1.166  1.485  1.899   2.991  1141 1.005
## sigma_site      3.395   0.035 1.125  1.844  2.586  3.185  3.961   6.070  1045 1.001
## sigma_tot       6.571   0.025 0.721  5.607  6.081  6.425  6.908   8.332   829 1.004
## pi[1]           0.666   0.004 0.115  0.402  0.597  0.681  0.751   0.847  1057 1.002
## pi[2]           0.065   0.001 0.046  0.013  0.032  0.053  0.083   0.186  1133 1.004
## pi[3]           0.269   0.004 0.119  0.101  0.180  0.248  0.338   0.545  1103 1.001
## difference      9.420   0.024 0.563  8.224  9.062  9.449  9.802  10.442   545 1.010
## 
## Samples were drawn using NUTS(diag_e) at Sat Jun  5 07:45:20 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Comparison of the estimated log-likelihoods among the three models compared below: the baseline model, the mixture model, and the mixture model with cultivar effects:

# LOO-CV
loo.mixCultivar.model <- loo(fit.mixCultivar.model)
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
loo::loo_compare(x=list(loo.baseline.model,loo.mixTarget.model,loo.mixCultivar.model))
##        elpd_diff se_diff
## model3   0.0       0.0  
## model2 -91.5      11.7  
## model1 -93.3      12.3

Finally, we can display, for each cultivar, the estimated probability of late flowering:

# Posterior draws of Pr(late flowering) per cultivar: inv_logit(eta0 + eta_j).
# extract()$eta is a (draws x n_cultivar) matrix; adding the eta0 draw vector
# recycles it column-major down the rows, so each draw's intercept is added to
# every cultivar column of that same draw. The previous construction,
# matrix(rep(eta0, each = n_cultivar), byrow = FALSE, ...), filled a
# row-major-intended sequence column-major and therefore paired eta0 and eta
# values coming from different posterior iterations.
proba_late <- plogis(rstan::extract(fit.mixCultivar.model, "eta")$eta +
                       rstan::extract(fit.mixCultivar.model, "eta0")$eta0)

# Summarise the per-cultivar posterior of Pr(late flowering) -- mean plus the
# lower/upper interval helpers -- and plot cultivars ordered by that mean.
data.frame(id = unique(dataSakura$Cultivar),
           proba = apply(proba_late, 2, mean),
           lower = apply(proba_late, 2, lower),
           upper = apply(proba_late, 2, upper)) %>%
  mutate(id = factor(id, levels = id[order(proba)])) %>%
  ggplot(aes(x = id, y = proba)) +
  geom_linerange(aes(ymin = lower, ymax = upper)) +
  geom_point() +
  labs(x = "cultivar", y = "Pr(Late Flowering)") +
  coord_flip() +
  theme(axis.text.y = element_text(size = 6))

5 Baseline model without the simplex

Out of curiosity, we would like to see whether omitting the simplex changes the estimates of \(\sigma\), \(\sigma_{age}\) and \(\sigma_{site}\). So, we refit the baseline model without the simplex.

baseline.model.nosimplex <- stan_model("BaselineModelCode_NoSimplex.stan")
data {
  int<lower = 1> n_obs;                                              // Total number of observations
  int<lower = 1> n_age;                                              // Number of different age classes
  int<lower = 1> n_site;                                             // Number of different sites
  vector[n_obs] FLOWERING;                                           // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];                          // Age class index per observation
  int<lower = 1, upper = n_site> SITE[n_obs];                        // Site index per observation
  real prior_location_beta0;                                         // Prior mean of the global intercept
  real<lower = 0.0> prior_scale_beta0;                               // Prior sd of the global intercept
}

parameters {
  real beta0;                                                        // global intercept
  vector[n_age] alpha;                                               // Age intercepts
  vector[n_site] delta;                                              // Site intercepts
  // The three standard deviations must be constrained positive: they receive
  // exponential(1) priors and are used as normal scales, so leaving them
  // unconstrained lets the sampler propose negative values that are rejected,
  // producing the divergent-transition warnings reported for this model.
  real<lower = 0.0> sigma;                                           // Residual standard deviation
  real<lower = 0.0> sigma_age;                                       // Standard deviation of the age intercepts
  real<lower = 0.0> sigma_site;                                      // Standard deviation of the site intercepts
}

transformed parameters {
  vector[n_obs] mu;                                                  // Linear predictor per observation
  
  // Global intercept plus the age- and site-specific random intercepts
  mu = rep_vector(beta0, n_obs) + alpha[AGE] + delta[SITE];
}

model {
  // Priors
  beta0 ~ normal(prior_location_beta0, prior_scale_beta0);           // Prior of the global intercept
  alpha ~ normal(0.0, sigma_age);                                    // Prior of the age intercepts
  delta ~ normal(0.0, sigma_site);                                   // Prior of the site intercepts
  sigma ~ exponential(1);                                            // Prior of the residual sd
  sigma_age ~ exponential(1);                                        // Prior of the age-intercept sd
  sigma_site ~ exponential(1);                                       // Prior of the site-intercept sd
  
  // Likelihood: vectorised Gaussian observation model
  FLOWERING ~ normal(mu, sigma);
}

generated quantities {
  vector[n_obs] log_lik;                                             // Pointwise log-likelihood (for LOO-CV)
  vector[n_obs] y_rep;                                               // Posterior predictive replicates
  
  for(i in 1:n_obs) {
    // Gaussian log-density of each observation under its linear predictor
    log_lik[i] = normal_lpdf(FLOWERING[i]| mu[i], sigma);
    // One simulated flowering date per observation, for predictive checks
    y_rep[i] = normal_rng(mu[i], sigma);
  }
}
# Fit the no-simplex baseline model with the same data and sampling settings
# as the baseline model, keeping the parameters needed for the comparison.
fit.baseline.model.nosimplex <- sampling(baseline.model.nosimplex,
                                         data = list.baseline.model,
                                         pars = c("beta0", "alpha", "delta",
                                                  "sigma", "sigma_age", "sigma_site",
                                                  "y_rep", "log_lik"),
                                         save_warmup = FALSE,  # FALSE, not F: F can be reassigned
                                         iter = 2000,
                                         chains = 4,
                                         cores = 4,
                                         thin = 1)
## Warning: There were 36 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
print(fit.baseline.model.nosimplex, 
      digits = 3, pars = c("beta0", "sigma", "sigma_age", "sigma_site"))
## Inference for Stan model: 19b46d68dfd80e5d0f57ac5fbd3de601.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##              mean se_mean    sd   2.5%    25%    50%    75%  97.5% n_eff  Rhat
## beta0      92.785   0.055 1.471 89.983 91.868 92.801 93.715 95.675   717 1.007
## sigma       7.089   0.003 0.167  6.770  6.973  7.089  7.198  7.421  3253 1.002
## sigma_age   1.029   0.021 0.499  0.267  0.674  0.957  1.311  2.123   542 1.001
## sigma_site  2.933   0.022 0.859  1.695  2.322  2.803  3.369  5.022  1466 1.000
## 
## Samples were drawn using NUTS(diag_e) at Sat Jun  5 07:46:00 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

This model produces more sampler warnings (including divergent transitions) than the baseline model with a simplex.

# Compare posterior medians and 95% credible intervals of the three standard
# deviations between the two baseline parameterisations (with/without simplex).
list(baseline.model = fit.baseline.model,
     baseline.model.nosimplex = fit.baseline.model.nosimplex) %>%
  mclapply(function(fit) {
    broom.mixed::tidyMCMC(fit,
                          pars = c("sigma", "sigma_site", "sigma_age"),
                          droppars = NULL,
                          estimate.method = "median",
                          ess = FALSE, rhat = FALSE,
                          conf.int = TRUE, conf.level = 0.95)
  }) %>%
  bind_rows(.id = "model") %>%
  ggplot(aes(x = term, y = estimate,
             ymin = conf.low, ymax = conf.high, color = model)) +
  geom_pointinterval(position = position_dodge(width = .8),
                     point_size = 5, alpha = 0.6, size = 8) +
  labs(x = "", y = "Standard deviation estimates", color = "Models") +
  theme(axis.text = element_text(size = 20),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.x = element_blank())

# Pairs plots of the variance-component draws, with NUTS diagnostics
# overlaid, to compare the posterior geometry of the two parameterisations.

# Baseline model with the simplex
np <- nuts_params(fit.baseline.model)               # divergences / treedepth info per draw
mcmc_pairs(as.array(fit.baseline.model),
           np = np,
           pars = c("sigma","sigma_site","sigma_age"),
           off_diag_args = list(size = 1, alpha = 1/3),
           np_style = pairs_style_np(div_size=1, div_shape = 19),
           max_treedepth = 10)

# Baseline model without the simplex
np <- nuts_params(fit.baseline.model.nosimplex)     # same diagnostics for the no-simplex fit
mcmc_pairs(as.array(fit.baseline.model.nosimplex),
           np = np,
           pars = c("sigma","sigma_site","sigma_age"),
           off_diag_args = list(size = 1, alpha = 1/3),
           np_style = pairs_style_np(div_size=1, div_shape = 19),
           max_treedepth = 10)

References

Wenden, Bénédicte, José Antonio Campoy, Julien Lecourt, Gregorio López Ortega, Michael Blanke, Sanja Radičević, Elisabeth Schüller, et al. 2016. “A Collection of European Sweet Cherry Phenology Data for Assessing Climate Change.” Scientific Data 3 (1). Nature Publishing Group: 1–10.